# -*- encoding: utf-8 -*-
import pandas as pd
import sklearn
import math

import requests
import pickle
import numpy as np
import json
import time

from os import path as P
from sklearn.model_selection import train_test_split
from sklearn import metrics

from cooka.common import dataset_util
from cooka.common import util
from cooka.common.model import PartitionClass

from deeptables.models.hyper_dt import HyperDT, DnnModule, DTModuleSpace, DTFit, DTEstimator
from deeptables.utils import consts as DT_consts

from hypernets.core.callbacks import *
from hypernets.core.ops import HyperInput, ModuleChoice
from hypernets.core.search_space import *
from hypernets.core.searcher import OptimizeDirection
from hypernets.searchers.mcts_searcher import MCTSSearcher

from hypergbm.estimators import LightGBMEstimator, XGBoostEstimator, CatBoostEstimator
from hypergbm.hyper_gbm import HyperGBM
from hypergbm.pipeline import DataFrameMapper
from hypergbm.sklearn.sklearn_ops import numeric_pipeline_complex, categorical_pipeline_simple
from hypernets.searchers.random_searcher import RandomSearcher
from hypergbm.search_space import search_space_general
from hypernets.experiment.general import GeneralExperiment
from hypergbm import CompeteExperiment
from tabular_toolbox.column_selector import column_object

{% if target_source_type == 'raw_python' %}
# Status strings for train-job events. Status_Succeed is passed as the
# `status` argument of train_callback by the trial callback in this file;
# Status_Failed is not referenced in this part of the file.
Status_Failed = 'failed'
Status_Succeed = 'succeed'
def train_callback(portal, train_job_name, dataset_name, type, status, took, extension, **kwargs):
    """Post one train-job event to the portal and return the reply's ``data`` field.

    Args:
        portal: Base URL of the server; the API path is appended to it.
        train_job_name: Train job name, used as the last segment of the URL.
        dataset_name: Dataset name, used inside the URL path.
        type: Event type, sent unchanged in the ``type`` field. The name shadows
            the builtin but is part of the signature.
        status: Event status, sent unchanged in the ``status`` field.
        took: Sent unchanged in the ``took`` field.
        extension: JSON-serializable value sent in the ``extension`` field.
        **kwargs: Accepted and ignored.

    Returns:
        The value stored under ``data`` in the decoded JSON reply.

    Raises:
        Exception: If the ``code`` field of the reply is not 0.
    """
    endpoint = f"{portal}/api/dataset/{dataset_name}/feature-series/default/train-job/{train_job_name}"

    # "datetime" carries the current wall-clock time in milliseconds.
    payload = json.dumps({
        "type": type,
        "status": status,
        "took": took,
        "datetime": round(time.time() * 1000),
        "extension": extension,
    })
    print(f"Send train process event: \n{endpoint}\n{payload}")

    reply = requests.post(
        endpoint,
        data=payload,
        timeout=20,
        headers={"Content-Type": "application/json;charset=utf-8"},
    )
    decoded = json.loads(reply.text)

    # A zero code marks success; anything else is surfaced with the raw reply body.
    if decoded["code"] == 0:
        return decoded['data']
    raise Exception(f"Update failed, {reply.text}")


class TrainTrialCallback(Callback):
    """Search callback that reports trial outcomes to the portal.

    Finished and skipped trials are posted through ``train_callback`` as
    ``optimize`` events; the remaining hooks do nothing.
    """

    def __init__(self, server_portal, train_job_name, dataset_name):
        """Remember where events are sent and which job and dataset they belong to."""
        super().__init__()
        self.server_portal = server_portal
        self.train_job_name = train_job_name
        self.dataset_name = dataset_name

    def on_build_estimator(self, hyper_model, space, estimator, trial_no):
        """No-op: nothing is reported when an estimator is built."""

    def on_trial_begin(self, hyper_model, space, trial_no):
        """No-op: nothing is reported when a trial starts."""

    def get_space_params(self, space):
        """Return the numeric hyper-parameter values of ``space`` keyed by short name.

        Parameters without any reference are skipped. The key is the parameter
        alias with the first reference's name, plus one following character,
        stripped from the front. Only ``int`` and ``float`` values are kept;
        ``bool`` values are excluded.
        """
        collected = {}
        for param in space.get_all_params():
            owners = list(param.references)
            if not owners:
                continue
            # Strip the leading reference name and the single character after it.
            short_name = param.alias[len(owners[0].name) + 1:]
            value = param.value
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            collected[short_name] = value
        return collected

    def ensure_number(self, value, var_name):
        """Raise ``ValueError`` unless ``value`` is an ``int`` or ``float`` instance.

        ``bool`` passes this check because it is a subclass of ``int``.
        """
        if value is None:
            raise ValueError(f"Var {var_name} can not be None.")
        if not isinstance(value, (float, int)):
            raise ValueError(f"Var {var_name} = {value} not a number.")

    def on_trial_end(self, hyper_model, space, trial_no, reward, improved, elapsed):
        """Validate the trial figures and post a ``succeed`` trial event.

        Raises:
            ValueError: If ``reward``, ``trial_no`` or ``elapsed`` is ``None``
                or not an ``int``/``float``.
        """
        for label, number in (('reward', reward), ('trial_no', trial_no), ('elapsed', elapsed)):
            self.ensure_number(number, label)

        outcome = {
            "reward": reward,
            "elapsed": elapsed,
            "params": self.get_space_params(space),
        }
        event = {"trial_no": trial_no, "status": 'succeed', "extension": outcome}
        train_callback(self.server_portal, self.train_job_name, self.dataset_name,
                       'optimize', Status_Succeed, elapsed, event)

    def on_skip_trial(self, hyper_model, space, trial_no, reason, reward, improved, elapsed):
        """Post a ``skip`` trial event carrying the stringified ``reason``."""
        event = {
            "trial_no": trial_no,
            "status": 'skip',
            "extension": {"reason": str(reason)},
        }
        train_callback(self.server_portal, self.train_job_name, self.dataset_name,
                       'optimize', Status_Succeed, elapsed, event)
{% else %}
# Emitted only when target_source_type is not 'raw_python'.
import shap
import matplotlib.pyplot as plt
# Runs once at import time, before the explainer import below. Per the SHAP
# documentation this loads the JavaScript used by interactive plots, so this
# branch presumably targets a notebook environment (TODO confirm).
shap.initjs()
from deeptables.utils.shap import DeepTablesExplainer
{% endif %}
